{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Chapter 19 – Training and Deploying TensorFlow Models at Scale**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_This notebook contains all the sample code in chapter 19._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/19_training_and_deploying_at_scale.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml2/blob/add-kaggle-badge/19_training_and_deploying_at_scale.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python ≥3.5 is required\n",
    "import sys\n",
    "assert sys.version_info >= (3, 5)\n",
    "\n",
    "# Is this notebook running on Colab or Kaggle?\n",
    "IS_COLAB = \"google.colab\" in sys.modules\n",
    "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
    "\n",
    "if IS_COLAB or IS_KAGGLE:\n",
    "    !echo \"deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal\" > /etc/apt/sources.list.d/tensorflow-serving.list\n",
    "    !curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -\n",
    "    !apt update && apt-get install -y tensorflow-model-server\n",
    "    !pip install -q -U tensorflow-serving-api\n",
    "\n",
    "# Scikit-Learn ≥0.20 is required\n",
    "import sklearn\n",
    "assert sklearn.__version__ >= \"0.20\"\n",
    "\n",
    "# TensorFlow ≥2.0 is required\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "assert tf.__version__ >= \"2.0\"\n",
    "\n",
    "if not tf.config.list_physical_devices('GPU'):\n",
    "    print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
    "    if IS_COLAB:\n",
    "        print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
    "    if IS_KAGGLE:\n",
    "        print(\"Go to Settings > Accelerator and select GPU.\")\n",
    "\n",
    "# Common imports\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# to make this notebook's output stable across runs\n",
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "# To plot pretty figures\n",
    "%matplotlib inline\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "mpl.rc('axes', labelsize=14)\n",
    "mpl.rc('xtick', labelsize=12)\n",
    "mpl.rc('ytick', labelsize=12)\n",
    "\n",
    "# Where to save the figures\n",
    "PROJECT_ROOT_DIR = \".\"\n",
    "CHAPTER_ID = \"deploy\"\n",
    "IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
    "os.makedirs(IMAGES_PATH, exist_ok=True)\n",
    "\n",
    "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
    "    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
    "    print(\"Saving figure\", fig_id)\n",
    "    if tight_layout:\n",
    "        plt.tight_layout()\n",
    "    plt.savefig(path, format=fig_extension, dpi=resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploying TensorFlow models to TensorFlow Serving (TFS)\n",
    "We will use the REST API or the gRPC API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save/Load a `SavedModel`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.mnist.load_data()\n",
    "X_train_full = X_train_full[..., np.newaxis].astype(np.float32) / 255.\n",
    "X_test = X_test[..., np.newaxis].astype(np.float32) / 255.\n",
    "X_valid, X_train = X_train_full[:5000], X_train_full[5000:]\n",
    "y_valid, y_train = y_train_full[:5000], y_train_full[5000:]\n",
    "X_new = X_test[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "1719/1719 [==============================] - 2s 1ms/step - loss: 1.1140 - accuracy: 0.7066 - val_loss: 0.3715 - val_accuracy: 0.9024\n",
      "Epoch 2/10\n",
      "1719/1719 [==============================] - 1s 713us/step - loss: 0.3695 - accuracy: 0.8981 - val_loss: 0.2990 - val_accuracy: 0.9144\n",
      "Epoch 3/10\n",
      "1719/1719 [==============================] - 1s 718us/step - loss: 0.3154 - accuracy: 0.9100 - val_loss: 0.2651 - val_accuracy: 0.9272\n",
      "Epoch 4/10\n",
      "1719/1719 [==============================] - 1s 706us/step - loss: 0.2765 - accuracy: 0.9223 - val_loss: 0.2436 - val_accuracy: 0.9334\n",
      "Epoch 5/10\n",
      "1719/1719 [==============================] - 1s 711us/step - loss: 0.2556 - accuracy: 0.9276 - val_loss: 0.2257 - val_accuracy: 0.9364\n",
      "Epoch 6/10\n",
      "1719/1719 [==============================] - 1s 715us/step - loss: 0.2367 - accuracy: 0.9321 - val_loss: 0.2121 - val_accuracy: 0.9396\n",
      "Epoch 7/10\n",
      "1719/1719 [==============================] - 1s 729us/step - loss: 0.2198 - accuracy: 0.9390 - val_loss: 0.1970 - val_accuracy: 0.9454\n",
      "Epoch 8/10\n",
      "1719/1719 [==============================] - 1s 716us/step - loss: 0.2057 - accuracy: 0.9425 - val_loss: 0.1880 - val_accuracy: 0.9476\n",
      "Epoch 9/10\n",
      "1719/1719 [==============================] - 1s 704us/step - loss: 0.1940 - accuracy: 0.9459 - val_loss: 0.1777 - val_accuracy: 0.9524\n",
      "Epoch 10/10\n",
      "1719/1719 [==============================] - 1s 711us/step - loss: 0.1798 - accuracy: 0.9482 - val_loss: 0.1684 - val_accuracy: 0.9546\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe3b8718590>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28, 1]),\n",
    "    keras.layers.Dense(100, activation=\"relu\"),\n",
    "    keras.layers.Dense(10, activation=\"softmax\")\n",
    "])\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "              optimizer=keras.optimizers.SGD(learning_rate=1e-2),\n",
    "              metrics=[\"accuracy\"])\n",
    "model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.97, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.round(model.predict(X_new), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'my_mnist_model/0001'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_version = \"0001\"\n",
    "model_name = \"my_mnist_model\"\n",
    "model_path = os.path.join(model_name, model_version)\n",
    "model_path"
   ]
  },
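  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The version directory name matters: TensorFlow Serving expects each version of a model in its own numerically named subdirectory and, by default, serves the highest-numbered one. As a minimal sketch (the helper `next_version_dir()` and the temporary demo directory are illustrative assumptions, not part of the chapter's code), the path for the next version could be computed like this:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "def next_version_dir(base_dir):\n",
    "    # hypothetical helper: numeric subdirectories count as existing versions\n",
    "    versions = [int(name) for name in os.listdir(base_dir)\n",
    "                if name.isdigit() and os.path.isdir(os.path.join(base_dir, name))]\n",
    "    next_version = max(versions, default=0) + 1\n",
    "    # zero-pad to four digits to match the \"0001\" naming used here\n",
    "    return os.path.join(base_dir, \"{:04d}\".format(next_version))\n",
    "\n",
    "# try it on a throwaway directory so nothing real is touched\n",
    "with tempfile.TemporaryDirectory() as demo_dir:\n",
    "    os.makedirs(os.path.join(demo_dir, \"0001\"))\n",
    "    demo_next = os.path.basename(next_version_dir(demo_dir))\n",
    "    print(demo_next)\n",
    "```\n",
    "\n",
    "Zero-padding is only a readability convention in this sketch; what counts is that the directory name parses as a number."
   ]
  },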
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf {model_name}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: my_mnist_model/0001/assets\n"
     ]
    }
   ],
   "source": [
    "tf.saved_model.save(model, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "my_mnist_model/\n",
      "    0001/\n",
      "        saved_model.pb\n",
      "        variables/\n",
      "            variables.data-00000-of-00001\n",
      "            variables.index\n",
      "        assets/\n"
     ]
    }
   ],
   "source": [
    "for root, dirs, files in os.walk(model_name):\n",
    "    indent = '    ' * root.count(os.sep)\n",
    "    print('{}{}/'.format(indent, os.path.basename(root)))\n",
    "    for filename in files:\n",
    "        print('{}{}'.format(indent + '    ', filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The given SavedModel contains the following tag-sets:\r\n",
      "'serve'\r\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir {model_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:\r\n",
      "SignatureDef key: \"__saved_model_init_op\"\r\n",
      "SignatureDef key: \"serving_default\"\r\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir {model_path} --tag_set serve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The given SavedModel SignatureDef contains the following input(s):\r\n",
      "  inputs['flatten_input'] tensor_info:\r\n",
      "      dtype: DT_FLOAT\r\n",
      "      shape: (-1, 28, 28, 1)\r\n",
      "      name: serving_default_flatten_input:0\r\n",
      "The given SavedModel SignatureDef contains the following output(s):\r\n",
      "  outputs['dense_1'] tensor_info:\r\n",
      "      dtype: DT_FLOAT\r\n",
      "      shape: (-1, 10)\r\n",
      "      name: StatefulPartitionedCall:0\r\n",
      "Method name is: tensorflow/serving/predict\r\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir {model_path} --tag_set serve \\\n",
    "                      --signature_def serving_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
      "\n",
      "signature_def['__saved_model_init_op']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['__saved_model_init_op'] tensor_info:\n",
      "        dtype: DT_INVALID\n",
      "        shape: unknown_rank\n",
      "        name: NoOp\n",
      "  Method name is: \n",
      "\n",
      "signature_def['serving_default']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "    inputs['flatten_input'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 28, 28, 1)\n",
      "        name: serving_default_flatten_input:0\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['dense_1'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 10)\n",
      "        name: StatefulPartitionedCall:0\n",
      "  Method name is: tensorflow/serving/predict\n",
      "\n",
      "Defined Functions:\n",
      "  Function Name: '__call__'\n",
      "    Option #1\n",
      "      Callable with:\n",
      "        Argument #1\n",
      "          flatten_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='flatten_input')\n",
      "        Argument #2\n",
      "          DType: bool\n",
      "          Value: True\n",
      "        Argument #3\n",
      "<<45 more lines>>\n",
      "          DType: bool\n",
      "          Value: False\n",
      "        Argument #3\n",
      "          DType: NoneType\n",
      "          Value: None\n",
      "    Option #2\n",
      "      Callable with:\n",
      "        Argument #1\n",
      "          inputs: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='inputs')\n",
      "        Argument #2\n",
      "          DType: bool\n",
      "          Value: True\n",
      "        Argument #3\n",
      "          DType: NoneType\n",
      "          Value: None\n",
      "    Option #3\n",
      "      Callable with:\n",
      "        Argument #1\n",
      "          flatten_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='flatten_input')\n",
      "        Argument #2\n",
      "          DType: bool\n",
      "          Value: True\n",
      "        Argument #3\n",
      "          DType: NoneType\n",
      "          Value: None\n",
      "    Option #4\n",
      "      Callable with:\n",
      "        Argument #1\n",
      "          flatten_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='flatten_input')\n",
      "        Argument #2\n",
      "          DType: bool\n",
      "          Value: False\n",
      "        Argument #3\n",
      "          DType: NoneType\n",
      "          Value: None\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir {model_path} --all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's write the new instances to a `npy` file so we can pass them easily to our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"my_mnist_tests.npy\", X_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'flatten_input'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_name = model.input_names[0]\n",
    "input_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now let's use `saved_model_cli` to make predictions for the instances we just saved:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2021-02-18 22:15:30.294109: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set\n",
      "2021-02-18 22:15:30.294306: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "WARNING:tensorflow:From /Users/ageron/miniconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:445: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.\n",
      "INFO:tensorflow:Restoring parameters from my_mnist_model/0001/variables/variables\n",
      "2021-02-18 22:15:30.323498: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:196] None of the MLIR optimization passes are enabled (registered 0 passes)\n",
      "Result for output key dense_1:\n",
      "[[1.1347984e-04 1.5187356e-07 9.7032893e-04 2.7640699e-03 3.7826971e-06\n",
      "  7.6876910e-05 3.9140293e-08 9.9559116e-01 5.3502394e-05 4.2665208e-04]\n",
      " [8.2443521e-04 3.5493889e-05 9.8826385e-01 7.0466995e-03 1.2957400e-07\n",
      "  2.3389691e-04 2.5639210e-03 9.5886099e-10 1.0314899e-03 8.7952529e-08]\n",
      " [4.4693781e-05 9.7028232e-01 9.0526715e-03 2.2641101e-03 4.8766597e-04\n",
      "  2.8800720e-03 2.2714981e-03 8.3753867e-03 4.0439744e-03 2.9759688e-04]]\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli run --dir {model_path} --tag_set serve \\\n",
    "                     --signature_def serving_default    \\\n",
    "                     --inputs {input_name}=my_mnist_tests.npy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.97, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.round([[1.1347984e-04, 1.5187356e-07, 9.7032893e-04, 2.7640699e-03, 3.7826971e-06,\n",
    "           7.6876910e-05, 3.9140293e-08, 9.9559116e-01, 5.3502394e-05, 4.2665208e-04],\n",
    "          [8.2443521e-04, 3.5493889e-05, 9.8826385e-01, 7.0466995e-03, 1.2957400e-07,\n",
    "           2.3389691e-04, 2.5639210e-03, 9.5886099e-10, 1.0314899e-03, 8.7952529e-08],\n",
    "          [4.4693781e-05, 9.7028232e-01, 9.0526715e-03, 2.2641101e-03, 4.8766597e-04,\n",
    "           2.8800720e-03, 2.2714981e-03, 8.3753867e-03, 4.0439744e-03, 2.9759688e-04]], 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow Serving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install [Docker](https://docs.docker.com/install/) if you don't have it already. Then run:\n",
    "\n",
    "```bash\n",
    "docker pull tensorflow/serving\n",
    "\n",
    "export ML_PATH=$HOME/ml # or wherever this project is\n",
    "docker run -it --rm -p 8500:8500 -p 8501:8501 \\\n",
    "   -v \"$ML_PATH/my_mnist_model:/models/my_mnist_model\" \\\n",
    "   -e MODEL_NAME=my_mnist_model \\\n",
    "   tensorflow/serving\n",
    "```\n",
    "Once you are finished using it, press Ctrl-C to shut down the server."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, if `tensorflow_model_server` is installed (e.g., if you are running this notebook in Colab), then the following 3 cells will start the server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"MODEL_DIR\"] = os.path.split(os.path.abspath(model_path))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash --bg\n",
    "nohup tensorflow_model_server \\\n",
    "     --rest_api_port=8501 \\\n",
    "     --model_name=my_mnist_model \\\n",
    "     --model_base_path=\"${MODEL_DIR}\" >server.log 2>&1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2021-02-16 22:33:09.323538: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:93] Reading SavedModel debug info (if present) from: /models/my_mnist_model/0001\n",
      "2021-02-16 22:33:09.323642: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2021-02-16 22:33:09.360572: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.\n",
      "2021-02-16 22:33:09.361764: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2200000000 Hz\n",
      "2021-02-16 22:33:09.387713: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/my_mnist_model/0001\n",
      "2021-02-16 22:33:09.392739: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 71106 microseconds.\n",
      "2021-02-16 22:33:09.393390: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/my_mnist_model/0001/assets.extra/tf_serving_warmup_requests\n",
      "2021-02-16 22:33:09.393847: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: my_mnist_model version: 1}\n",
      "2021-02-16 22:33:09.398470: I tensorflow_serving/model_servers/server.cc:371] Running gRPC ModelServer at 0.0.0.0:8500 ...\n",
      "[warn] getaddrinfo: address family for nodename not supported\n",
      "2021-02-16 22:33:09.405622: I tensorflow_serving/model_servers/server.cc:391] Exporting HTTP/REST API at:localhost:8501 ...\n",
      "[evhttp_server.cc : 238] NET_LOG: Entering the event loop ...\n"
     ]
    }
   ],
   "source": [
    "!tail server.log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "input_data_json = json.dumps({\n",
    "    \"signature_name\": \"serving_default\",\n",
    "    \"instances\": X_new.tolist(),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\'{\"signature_name\": \"serving_default\", \"instances\": [[[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.32941177487373...'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "repr(input_data_json)[:1500] + \"...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's use TensorFlow Serving's REST API to make predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "SERVER_URL = 'http://localhost:8501/v1/models/my_mnist_model:predict'\n",
    "response = requests.post(SERVER_URL, data=input_data_json)\n",
    "response.raise_for_status() # raise an exception in case of error\n",
    "response = response.json()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['predictions'])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.97, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = np.array(response[\"predictions\"])\n",
    "y_proba.round(2)"
   ]
  },
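  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The request follows a simple pattern: the JSON payload carries a `signature_name` and a list of `instances`, and the URL names the model, optionally pinning a version with a `/versions/N` segment. A minimal sketch of both pieces, runnable without a server (the helpers `predict_url()` and `build_payload()` and the stand-in response values are assumptions made for illustration):\n",
    "\n",
    "```python\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "def predict_url(model_name, version=None, host=\"localhost\", port=8501):\n",
    "    # hypothetical helper: omit the version to query whichever one is served by default\n",
    "    path = \"/v1/models/{}\".format(model_name)\n",
    "    if version is not None:\n",
    "        path += \"/versions/{}\".format(version)\n",
    "    return \"http://{}:{}{}:predict\".format(host, port, path)\n",
    "\n",
    "def build_payload(instances, signature_name=\"serving_default\"):\n",
    "    # hypothetical helper: row-oriented request, one entry per instance\n",
    "    return json.dumps({\"signature_name\": signature_name,\n",
    "                       \"instances\": np.asarray(instances).tolist()})\n",
    "\n",
    "url_default = predict_url(\"my_mnist_model\")\n",
    "url_pinned = predict_url(\"my_mnist_model\", version=1)\n",
    "payload = build_payload(np.zeros((2, 28, 28, 1), dtype=np.float32))\n",
    "\n",
    "# stand-in for response.json(); these numbers are made up, not model output\n",
    "fake_response = {\"predictions\": [[0.1] * 10, [0.1] * 10]}\n",
    "y_sketch = np.array(fake_response[\"predictions\"])\n",
    "print(url_pinned, y_sketch.shape)\n",
    "```\n",
    "\n",
    "Pinning a version in the URL can be handy when several versions of the same model are being served side by side."
   ]
  },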
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the gRPC API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow_serving.apis.predict_pb2 import PredictRequest\n",
    "\n",
    "request = PredictRequest()\n",
    "request.model_spec.name = model_name\n",
    "request.model_spec.signature_name = \"serving_default\"\n",
    "input_name = model.input_names[0]\n",
    "request.inputs[input_name].CopyFrom(tf.make_tensor_proto(X_new))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import grpc\n",
    "from tensorflow_serving.apis import prediction_service_pb2_grpc\n",
    "\n",
    "channel = grpc.insecure_channel('localhost:8500')\n",
    "predict_service = prediction_service_pb2_grpc.PredictionServiceStub(channel)\n",
    "response = predict_service.Predict(request, timeout=10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "outputs {\n",
       "  key: \"dense_1\"\n",
       "  value {\n",
       "    dtype: DT_FLOAT\n",
       "    tensor_shape {\n",
       "      dim {\n",
       "        size: 3\n",
       "      }\n",
       "      dim {\n",
       "        size: 10\n",
       "      }\n",
       "    }\n",
       "    float_val: 0.00011425172124290839\n",
       "    float_val: 1.513665068841874e-07\n",
       "    float_val: 0.0009818424005061388\n",
       "    float_val: 0.0027773496694862843\n",
       "    float_val: 3.758880893656169e-06\n",
       "    float_val: 7.6266449468676e-05\n",
       "    float_val: 3.9139514740327286e-08\n",
       "    float_val: 0.995561957359314\n",
       "    float_val: 5.344580131350085e-05\n",
       "    float_val: 0.00043088122038170695\n",
       "    float_val: 0.0008194865076802671\n",
       "    float_val: 3.5498320357874036e-05\n",
       "    float_val: 0.9882420897483826\n",
       "    float_val: 0.00705744931474328\n",
       "    float_val: 1.2937064752804872e-07\n",
       "    float_val: 0.00023402832448482513\n",
       "    float_val: 0.0025743397418409586\n",
       "    float_val: 9.668431610876382e-10\n",
       "    float_val: 0.0010369382798671722\n",
       "    float_val: 8.833576004008137e-08\n",
       "    float_val: 4.441547571332194e-05\n",
       "    float_val: 0.970328688621521\n",
       "    float_val: 0.009044423699378967\n",
       "    float_val: 0.0022599005606025457\n",
       "    float_val: 0.00048672096454538405\n",
       "    float_val: 0.002873610006645322\n",
       "    float_val: 0.002268279204145074\n",
       "    float_val: 0.008354829624295235\n",
       "    float_val: 0.004041312728077173\n",
       "    float_val: 0.0002978229313157499\n",
       "  }\n",
       "}\n",
       "model_spec {\n",
       "  name: \"my_mnist_model\"\n",
       "  version {\n",
       "    value: 1\n",
       "  }\n",
       "  signature_name: \"serving_default\"\n",
       "}"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the response to a tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.97, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_name = model.output_names[0]\n",
    "outputs_proto = response.outputs[output_name]\n",
    "y_proba = tf.make_ndarray(outputs_proto)\n",
    "y_proba.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or to a NumPy array if your client does not include the TensorFlow library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.97, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_name = model.output_names[0]\n",
    "outputs_proto = response.outputs[output_name]\n",
    "shape = [dim.size for dim in outputs_proto.tensor_shape.dim]\n",
    "y_proba = np.array(outputs_proto.float_val).reshape(shape)\n",
    "y_proba.round(2)"
   ]
  },
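  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The manual conversion works because, in this response, the proto stores the values as one flat, row-major list alongside the tensor's dimensions, so a reshape recovers the original layout. A tiny sketch with made-up stand-in values (`flat_values` and `dim_sizes` play the roles of `outputs_proto.float_val` and the `tensor_shape` dimensions):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# made-up stand-ins: 3 instances x 2 classes, flattened row by row\n",
    "flat_values = [0.1, 0.9, 0.8, 0.2, 0.5, 0.5]\n",
    "dim_sizes = [3, 2]\n",
    "\n",
    "# reshape fills rows first, so each instance gets its own row back\n",
    "y_sketch = np.array(flat_values, dtype=np.float32).reshape(dim_sizes)\n",
    "print(y_sketch)\n",
    "```\n",
    "\n",
    "Passing `dtype=np.float32` is optional; it simply keeps the result in the same precision that `tf.make_ndarray()` returns here."
   ]
  },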
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploying a new model version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "1719/1719 [==============================] - 1s 748us/step - loss: 1.1567 - accuracy: 0.6691 - val_loss: 0.3418 - val_accuracy: 0.9042\n",
      "Epoch 2/10\n",
      "1719/1719 [==============================] - 1s 697us/step - loss: 0.3376 - accuracy: 0.9032 - val_loss: 0.2674 - val_accuracy: 0.9242\n",
      "Epoch 3/10\n",
      "1719/1719 [==============================] - 1s 676us/step - loss: 0.2779 - accuracy: 0.9187 - val_loss: 0.2227 - val_accuracy: 0.9368\n",
      "Epoch 4/10\n",
      "1719/1719 [==============================] - 1s 669us/step - loss: 0.2362 - accuracy: 0.9318 - val_loss: 0.2032 - val_accuracy: 0.9432\n",
      "Epoch 5/10\n",
      "1719/1719 [==============================] - 1s 670us/step - loss: 0.2109 - accuracy: 0.9389 - val_loss: 0.1833 - val_accuracy: 0.9482\n",
      "Epoch 6/10\n",
      "1719/1719 [==============================] - 1s 675us/step - loss: 0.1951 - accuracy: 0.9430 - val_loss: 0.1740 - val_accuracy: 0.9498\n",
      "Epoch 7/10\n",
      "1719/1719 [==============================] - 1s 667us/step - loss: 0.1799 - accuracy: 0.9474 - val_loss: 0.1605 - val_accuracy: 0.9540\n",
      "Epoch 8/10\n",
      "1719/1719 [==============================] - 1s 673us/step - loss: 0.1654 - accuracy: 0.9519 - val_loss: 0.1543 - val_accuracy: 0.9558\n",
      "Epoch 9/10\n",
      "1719/1719 [==============================] - 1s 671us/step - loss: 0.1570 - accuracy: 0.9554 - val_loss: 0.1460 - val_accuracy: 0.9572\n",
      "Epoch 10/10\n",
      "1719/1719 [==============================] - 1s 672us/step - loss: 0.1420 - accuracy: 0.9583 - val_loss: 0.1359 - val_accuracy: 0.9616\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(42)\n",
    "tf.random.set_seed(42)\n",
    "\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28, 1]),\n",
    "    keras.layers.Dense(50, activation=\"relu\"),\n",
    "    keras.layers.Dense(50, activation=\"relu\"),\n",
    "    keras.layers.Dense(10, activation=\"softmax\")\n",
    "])\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "              optimizer=keras.optimizers.SGD(learning_rate=1e-2),\n",
    "              metrics=[\"accuracy\"])\n",
    "history = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'my_mnist_model/0002'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_version = \"0002\"\n",
    "model_name = \"my_mnist_model\"\n",
    "model_path = os.path.join(model_name, model_version)\n",
    "model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: my_mnist_model/0002/assets\n"
     ]
    }
   ],
   "source": [
    "tf.saved_model.save(model, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "my_mnist_model/\n",
      "    0001/\n",
      "        saved_model.pb\n",
      "        variables/\n",
      "            variables.data-00000-of-00001\n",
      "            variables.index\n",
      "        assets/\n",
      "    0002/\n",
      "        saved_model.pb\n",
      "        variables/\n",
      "            variables.data-00000-of-00001\n",
      "            variables.index\n",
      "        assets/\n"
     ]
    }
   ],
   "source": [
    "for root, dirs, files in os.walk(model_name):\n",
    "    indent = '    ' * root.count(os.sep)\n",
    "    print('{}{}/'.format(indent, os.path.basename(root)))\n",
    "    for filename in files:\n",
    "        print('{}{}'.format(indent + '    ', filename))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Warning**: You may need to wait a minute before the new model is loaded by TensorFlow Serving."
   ]
  },
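  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of waiting for a fixed amount of time, you could poll TensorFlow Serving's model status endpoint until the new version is reported as available. The following is only a sketch: it assumes the REST API is listening on `localhost:8501` as in the cells above, and `wait_for_version()` with its `timeout` and `poll_interval` parameters is a hypothetical helper introduced here for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import requests\n",
    "\n",
    "# Assumption: TF Serving's REST API listens on localhost:8501, as in the cells above\n",
    "STATUS_URL = \"http://localhost:8501/v1/models/my_mnist_model\"\n",
    "\n",
    "def wait_for_version(version, timeout=120, poll_interval=5):\n",
    "    \"\"\"Hypothetical helper: poll until `version` is reported as AVAILABLE.\"\"\"\n",
    "    deadline = time.time() + timeout\n",
    "    while time.time() < deadline:\n",
    "        status = requests.get(STATUS_URL).json()\n",
    "        # The status endpoint lists one entry per version known to the server\n",
    "        for version_status in status.get(\"model_version_status\", []):\n",
    "            if (version_status[\"version\"] == str(version)\n",
    "                    and version_status[\"state\"] == \"AVAILABLE\"):\n",
    "                return True\n",
    "        time.sleep(poll_interval)\n",
    "    return False  # the version did not become available before the timeout\n",
    "\n",
    "wait_for_version(int(model_version))  # \"0002\" is served as version 2"
   ]
  },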
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "SERVER_URL = 'http://localhost:8501/v1/models/my_mnist_model:predict'\n",
    "            \n",
    "response = requests.post(SERVER_URL, data=input_data_json)\n",
    "response.raise_for_status()\n",
    "response = response.json()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['predictions'])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ]])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_proba = np.array(response[\"predictions\"])\n",
    "y_proba.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy the model to Google Cloud AI Platform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow the instructions in the book to deploy the model to Google Cloud AI Platform, then download the service account's private key and save it as `my_service_account_private_key.json` in the project directory. Also, update the `project_id`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "project_id = \"onyx-smoke-242003\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import googleapiclient.discovery\n",
    "\n",
    "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"my_service_account_private_key.json\"\n",
    "model_id = \"my_mnist_model\"\n",
    "model_path = \"projects/{}/models/{}\".format(project_id, model_id)\n",
    "model_path += \"/versions/v0001/\" # if you want to run a specific version\n",
    "ml_resource = googleapiclient.discovery.build(\"ml\", \"v1\").projects()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(X):\n",
    "    input_data_json = {\"signature_name\": \"serving_default\",\n",
    "                       \"instances\": X.tolist()}\n",
    "    request = ml_resource.predict(name=model_path, body=input_data_json)\n",
    "    response = request.execute()\n",
    "    if \"error\" in response:\n",
    "        raise RuntimeError(response[\"error\"])\n",
    "    return np.array([pred[output_name] for pred in response[\"predictions\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],\n",
       "       [0.  , 0.99, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_probas = predict(X_new)\n",
    "np.round(Y_probas, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using GPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: `tf.test.is_gpu_available()` is deprecated. Instead, please use `tf.config.list_physical_devices('GPU')`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),\n",
       " PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#tf.test.is_gpu_available() # deprecated\n",
    "tf.config.list_physical_devices('GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/device:GPU:0'"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.test.gpu_device_name()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.test.is_built_with_cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[name: \"/device:CPU:0\"\n",
       " device_type: \"CPU\"\n",
       " memory_limit: 268435456\n",
       " locality {\n",
       " }\n",
       " incarnation: 7325002731160755624,\n",
       " name: \"/device:GPU:0\"\n",
       " device_type: \"GPU\"\n",
       " memory_limit: 11139884032\n",
       " locality {\n",
       "   bus_id: 1\n",
       "   links {\n",
       "     link {\n",
       "       device_id: 1\n",
       "       type: \"StreamExecutor\"\n",
       "       strength: 1\n",
       "     }\n",
       "   }\n",
       " }\n",
       " incarnation: 7150956550266107441\n",
       " physical_device_desc: \"device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7\",\n",
       " name: \"/device:GPU:1\"\n",
       " device_type: \"GPU\"\n",
       " memory_limit: 11139884032\n",
       " locality {\n",
       "   bus_id: 1\n",
       "   links {\n",
       "     link {\n",
       "       type: \"StreamExecutor\"\n",
       "       strength: 1\n",
       "     }\n",
       "   }\n",
       " }\n",
       " incarnation: 15909479382059415698\n",
       " physical_device_desc: \"device: 1, name: Tesla K80, pci bus id: 0000:00:05.0, compute capability: 3.7\"]"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensorflow.python.client.device_lib import list_local_devices\n",
    "\n",
    "devices = list_local_devices()\n",
    "devices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "tf.random.set_seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_model():\n",
    "    return keras.models.Sequential([\n",
    "        keras.layers.Conv2D(filters=64, kernel_size=7, activation=\"relu\",\n",
    "                            padding=\"same\", input_shape=[28, 28, 1]),\n",
    "        keras.layers.MaxPooling2D(pool_size=2),\n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, activation=\"relu\",\n",
    "                            padding=\"same\"), \n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, activation=\"relu\",\n",
    "                            padding=\"same\"),\n",
    "        keras.layers.MaxPooling2D(pool_size=2),\n",
    "        keras.layers.Flatten(),\n",
    "        keras.layers.Dense(units=64, activation='relu'),\n",
    "        keras.layers.Dropout(0.5),\n",
    "        keras.layers.Dense(units=10, activation='softmax'),\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "550/550 [==============================] - 11s 18ms/step - loss: 1.8163 - accuracy: 0.3979 - val_loss: 0.3446 - val_accuracy: 0.9010\n",
      "Epoch 2/10\n",
      "550/550 [==============================] - 9s 17ms/step - loss: 0.4949 - accuracy: 0.8482 - val_loss: 0.1962 - val_accuracy: 0.9458\n",
      "Epoch 3/10\n",
      "550/550 [==============================] - 10s 17ms/step - loss: 0.3345 - accuracy: 0.9012 - val_loss: 0.1343 - val_accuracy: 0.9622\n",
      "Epoch 4/10\n",
      "550/550 [==============================] - 10s 17ms/step - loss: 0.2537 - accuracy: 0.9267 - val_loss: 0.1049 - val_accuracy: 0.9718\n",
      "Epoch 5/10\n",
      "550/550 [==============================] - 10s 17ms/step - loss: 0.2099 - accuracy: 0.9394 - val_loss: 0.0875 - val_accuracy: 0.9752\n",
      "Epoch 6/10\n",
      "550/550 [==============================] - 10s 17ms/step - loss: 0.1901 - accuracy: 0.9439 - val_loss: 0.0797 - val_accuracy: 0.9772\n",
      "Epoch 7/10\n",
      "550/550 [==============================] - 10s 18ms/step - loss: 0.1672 - accuracy: 0.9506 - val_loss: 0.0745 - val_accuracy: 0.9780\n",
      "Epoch 8/10\n",
      "550/550 [==============================] - 10s 18ms/step - loss: 0.1537 - accuracy: 0.9554 - val_loss: 0.0700 - val_accuracy: 0.9804\n",
      "Epoch 9/10\n",
      "550/550 [==============================] - 10s 18ms/step - loss: 0.1384 - accuracy: 0.9592 - val_loss: 0.0641 - val_accuracy: 0.9818\n",
      "Epoch 10/10\n",
      "550/550 [==============================] - 10s 18ms/step - loss: 0.1358 - accuracy: 0.9602 - val_loss: 0.0611 - val_accuracy: 0.9818\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f014831c110>"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 100\n",
    "model = create_model()\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "              optimizer=keras.optimizers.SGD(learning_rate=1e-2),\n",
    "              metrics=[\"accuracy\"])\n",
    "model.fit(X_train, y_train, epochs=10,\n",
    "          validation_data=(X_valid, y_valid), batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n"
     ]
    }
   ],
   "source": [
    "keras.backend.clear_session()\n",
    "tf.random.set_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "distribution = tf.distribute.MirroredStrategy()\n",
    "\n",
    "# Change the default all-reduce algorithm:\n",
    "#distribution = tf.distribute.MirroredStrategy(\n",
    "#    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())\n",
    "\n",
    "# Specify the list of GPUs to use:\n",
    "#distribution = tf.distribute.MirroredStrategy(devices=[\"/gpu:0\", \"/gpu:1\"])\n",
    "\n",
    "# Use the central storage strategy instead:\n",
    "#distribution = tf.distribute.experimental.CentralStorageStrategy()\n",
    "\n",
    "#if IS_COLAB and \"COLAB_TPU_ADDR\" in os.environ:\n",
    "#  tpu_address = \"grpc://\" + os.environ[\"COLAB_TPU_ADDR\"]\n",
    "#else:\n",
    "#  tpu_address = \"\"\n",
    "#resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)\n",
    "#tf.config.experimental_connect_to_cluster(resolver)\n",
    "#tf.tpu.experimental.initialize_tpu_system(resolver)\n",
    "#distribution = tf.distribute.experimental.TPUStrategy(resolver)\n",
    "\n",
    "with distribution.scope():\n",
    "    model = create_model()\n",
    "    model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "                  optimizer=keras.optimizers.SGD(learning_rate=1e-2),\n",
    "                  metrics=[\"accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "INFO:tensorflow:batch_all_reduce: 10 all-reduces with algorithm = nccl, num_packs = 1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:batch_all_reduce: 10 all-reduces with algorithm = nccl, num_packs = 1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "550/550 [==============================] - 14s 16ms/step - loss: 1.8193 - accuracy: 0.3957 - val_loss: 0.3366 - val_accuracy: 0.9102\n",
      "Epoch 2/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.4886 - accuracy: 0.8497 - val_loss: 0.1865 - val_accuracy: 0.9478\n",
      "Epoch 3/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.3305 - accuracy: 0.9008 - val_loss: 0.1344 - val_accuracy: 0.9616\n",
      "Epoch 4/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.2472 - accuracy: 0.9282 - val_loss: 0.1115 - val_accuracy: 0.9696\n",
      "Epoch 5/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.2020 - accuracy: 0.9425 - val_loss: 0.0873 - val_accuracy: 0.9748\n",
      "Epoch 6/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.1865 - accuracy: 0.9458 - val_loss: 0.0783 - val_accuracy: 0.9764\n",
      "Epoch 7/10\n",
      "550/550 [==============================] - 8s 14ms/step - loss: 0.1633 - accuracy: 0.9512 - val_loss: 0.0771 - val_accuracy: 0.9776\n",
      "Epoch 8/10\n",
      "550/550 [==============================] - 8s 14ms/step - loss: 0.1422 - accuracy: 0.9570 - val_loss: 0.0705 - val_accuracy: 0.9786\n",
      "Epoch 9/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.1408 - accuracy: 0.9603 - val_loss: 0.0627 - val_accuracy: 0.9830\n",
      "Epoch 10/10\n",
      "550/550 [==============================] - 7s 13ms/step - loss: 0.1293 - accuracy: 0.9618 - val_loss: 0.0605 - val_accuracy: 0.9836\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f259bf1a6d0>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 100 # must be divisible by the number of replicas (here, the GPUs)\n",
    "model.fit(X_train, y_train, epochs=10,\n",
    "          validation_data=(X_valid, y_valid), batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2.53707055e-10, 7.94509292e-10, 1.02021443e-06, 3.37102080e-08,\n",
       "        4.90816797e-11, 4.37713789e-11, 2.43314297e-14, 9.99996424e-01,\n",
       "        1.50591750e-09, 2.50736753e-06],\n",
       "       [1.11715025e-07, 8.56921833e-05, 9.99914169e-01, 6.31697228e-09,\n",
       "        3.99949344e-11, 4.47976906e-10, 8.46022008e-09, 3.03771834e-08,\n",
       "        2.91782563e-08, 1.95555502e-10],\n",
       "       [4.68117065e-07, 9.99787748e-01, 1.01387537e-04, 2.87393277e-06,\n",
       "        5.29725839e-05, 1.55926125e-06, 2.07211669e-05, 1.76809226e-05,\n",
       "        9.37155255e-06, 5.19965897e-06]], dtype=float32)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(X_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Custom training loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n",
      "WARNING:tensorflow:From <ipython-input-9-acb7c62c8738>:36: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the iterator's `initializer` property instead.\n",
      "Epoch 1/10\n",
      "INFO:tensorflow:batch_all_reduce: 10 all-reduces with algorithm = nccl, num_packs = 1\n",
      "INFO:tensorflow:batch_all_reduce: 10 all-reduces with algorithm = nccl, num_packs = 1\n",
      "Loss: 0.380\n",
      "Epoch 2/10\n",
      "Loss: 0.302\n",
      "Epoch 3/10\n",
      "Loss: 0.285\n",
      "Epoch 4/10\n",
      "Loss: 0.294\n",
      "Epoch 5/10\n",
      "Loss: 0.304\n",
      "Epoch 6/10\n",
      "Loss: 0.310\n",
      "Epoch 7/10\n",
      "Loss: 0.310\n",
      "Epoch 8/10\n",
      "Loss: 0.306\n",
      "Epoch 9/10\n",
      "Loss: 0.303\n",
      "Epoch 10/10\n",
      "Loss: 0.298\n"
     ]
    }
   ],
   "source": [
    "keras.backend.clear_session()\n",
    "tf.random.set_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "K = keras.backend\n",
    "\n",
    "distribution = tf.distribute.MirroredStrategy()\n",
    "\n",
    "with distribution.scope():\n",
    "    model = create_model()\n",
    "    optimizer = keras.optimizers.SGD()\n",
    "\n",
    "with distribution.scope():\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).repeat().batch(batch_size)\n",
    "    input_iterator = distribution.make_dataset_iterator(dataset)\n",
    "    \n",
    "@tf.function\n",
    "def train_step():\n",
    "    def step_fn(inputs):\n",
    "        X, y = inputs\n",
    "        with tf.GradientTape() as tape:\n",
    "            Y_proba = model(X)\n",
    "            loss = K.sum(keras.losses.sparse_categorical_crossentropy(y, Y_proba)) / batch_size\n",
    "\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "        return loss\n",
    "\n",
    "    per_replica_losses = distribution.experimental_run(step_fn, input_iterator)\n",
    "    mean_loss = distribution.reduce(tf.distribute.ReduceOp.SUM,\n",
    "                                    per_replica_losses, axis=None)\n",
    "    return mean_loss\n",
    "\n",
    "n_epochs = 10\n",
    "with distribution.scope():\n",
    "    input_iterator.initialize()\n",
    "    for epoch in range(n_epochs):\n",
    "        print(\"Epoch {}/{}\".format(epoch + 1, n_epochs))\n",
    "        for iteration in range(len(X_train) // batch_size):\n",
    "            print(\"\\rLoss: {:.3f}\".format(train_step().numpy()), end=\"\")\n",
    "        print()"
   ]
  },
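  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `make_dataset_iterator()` and `experimental_run()` methods used above are deprecated (the output shows a related deprecation warning) and may be unavailable in recent TensorFlow releases. The following is a sketch of the same loop, assuming TensorFlow 2.2 or later, where `experimental_distribute_dataset()` and `Strategy.run()` are available. It reuses `create_model()`, `batch_size` and `n_epochs` from the cells above. One deliberate difference is that it calls the model with `training=True`, so that the `Dropout` layer is active during training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras.backend.clear_session()\n",
    "tf.random.set_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "distribution = tf.distribute.MirroredStrategy()\n",
    "\n",
    "with distribution.scope():\n",
    "    model = create_model()\n",
    "    optimizer = keras.optimizers.SGD()\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).repeat().batch(batch_size)\n",
    "# Split each global batch across the replicas\n",
    "dist_dataset = distribution.experimental_distribute_dataset(dataset)\n",
    "\n",
    "@tf.function\n",
    "def train_step(inputs):\n",
    "    def step_fn(inputs):\n",
    "        X, y = inputs\n",
    "        with tf.GradientTape() as tape:\n",
    "            Y_proba = model(X, training=True)\n",
    "            # Divide by the global batch size so that summing over replicas gives the mean\n",
    "            loss = tf.reduce_sum(keras.losses.sparse_categorical_crossentropy(y, Y_proba)) / batch_size\n",
    "\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "        return loss\n",
    "\n",
    "    per_replica_losses = distribution.run(step_fn, args=(inputs,))\n",
    "    return distribution.reduce(tf.distribute.ReduceOp.SUM,\n",
    "                               per_replica_losses, axis=None)\n",
    "\n",
    "iterator = iter(dist_dataset)\n",
    "for epoch in range(n_epochs):\n",
    "    print(\"Epoch {}/{}\".format(epoch + 1, n_epochs))\n",
    "    for iteration in range(len(X_train) // batch_size):\n",
    "        loss = train_step(next(iterator))\n",
    "        print(\"\\rLoss: {:.3f}\".format(loss.numpy()), end=\"\")\n",
    "    print()"
   ]
  },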
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training across multiple servers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A TensorFlow cluster is a group of TensorFlow processes running in parallel, usually on different machines, and talking to each other to complete some work, for example training or executing a neural network. Each TF process in the cluster is called a \"task\" (or a \"TF server\"). It has an IP address, a port, and a type (also called its role or its job). The type can be `\"worker\"`, `\"chief\"`, `\"ps\"` (parameter server) or `\"evaluator\"`:\n",
    "* Each **worker** performs computations, usually on a machine with one or more GPUs.\n",
    "* The **chief** performs computations as well, but it also handles extra work such as writing TensorBoard logs or saving checkpoints. There is a single chief in a cluster, typically the first worker (i.e., worker #0).\n",
    "* A **parameter server** (ps) only keeps track of variable values; it usually runs on a CPU-only machine.\n",
    "* The **evaluator** takes care of evaluation. There is usually a single evaluator in a cluster.\n",
    "\n",
    "The set of tasks that share the same type is often called a \"job\". For example, the \"worker\" job is the set of all workers.\n",
    "\n",
    "To start a TensorFlow cluster, you must first define it. This means specifying all the tasks (IP address, TCP port, and type). For example, the following cluster specification defines a cluster with 3 tasks (2 workers and 1 parameter server). It's a dictionary with one key per job, and the values are lists of task addresses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster_spec = {\n",
    "    \"worker\": [\n",
    "        \"machine-a.example.com:2222\",  # /job:worker/task:0\n",
    "        \"machine-b.example.com:2222\"   # /job:worker/task:1\n",
    "    ],\n",
    "    \"ps\": [\"machine-c.example.com:2222\"] # /job:ps/task:0\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every task in the cluster may communicate with every other task in the cluster, so make sure to configure your firewall to authorize all communications between these machines on these ports (it's usually simpler if you use the same port on every machine).\n",
    "\n",
    "When a task is started, it needs to be told which one it is: its type and index (the task index is also called the task id). A common way to specify everything at once (both the cluster spec and the current task's type and id) is to set the `TF_CONFIG` environment variable before starting the program. It must be a JSON-encoded dictionary containing a cluster specification (under the `\"cluster\"` key), and the type and index of the task to start (under the `\"task\"` key). For example, the following `TF_CONFIG` environment variable defines the same cluster as above, with 2 workers and 1 parameter server, and specifies that the task to start is worker #1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"cluster\": {\"worker\": [\"machine-a.example.com:2222\", \"machine-b.example.com:2222\"], \"ps\": [\"machine-c.example.com:2222\"]}, \"task\": {\"type\": \"worker\", \"index\": 1}}'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "os.environ[\"TF_CONFIG\"] = json.dumps({\n",
    "    \"cluster\": cluster_spec,\n",
    "    \"task\": {\"type\": \"worker\", \"index\": 1}\n",
    "})\n",
    "os.environ[\"TF_CONFIG\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some platforms (e.g., Google Cloud ML Engine) automatically set this environment variable for you."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TensorFlow's `TFConfigClusterResolver` class reads the cluster configuration from this environment variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ClusterSpec({'ps': ['machine-c.example.com:2222'], 'worker': ['machine-a.example.com:2222', 'machine-b.example.com:2222']})"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()\n",
    "resolver.cluster_spec()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'worker'"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resolver.task_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resolver.task_id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run a simpler cluster with just two worker tasks, both running on the local machine. We will use the `MultiWorkerMirroredStrategy` to train a model across these two tasks.\n",
    "\n",
    "The first step is to write the training code. As this code will be used to run both workers, each in its own process, we write it to a separate Python file, `my_mnist_multiworker_task.py`. The code is relatively straightforward, but there are a couple of important things to note:\n",
    "* We create the `MultiWorkerMirroredStrategy` before doing anything else with TensorFlow.\n",
    "* Only one of the workers will take care of logging to TensorBoard and saving checkpoints. As mentioned earlier, this worker is called the *chief*, and by convention it is usually worker #0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting my_mnist_multiworker_task.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile my_mnist_multiworker_task.py\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import time\n",
    "\n",
    "# At the beginning of the program\n",
    "distribution = tf.distribute.MultiWorkerMirroredStrategy()\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()\n",
    "print(\"Starting task {}{}\".format(resolver.task_type, resolver.task_id))\n",
    "\n",
    "# Only worker #0 will write checkpoints and log to TensorBoard\n",
    "if resolver.task_id == 0:\n",
    "    root_logdir = os.path.join(os.curdir, \"my_mnist_multiworker_logs\")\n",
    "    run_id = time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")\n",
    "    run_dir = os.path.join(root_logdir, run_id)\n",
    "    callbacks = [\n",
    "        keras.callbacks.TensorBoard(run_dir),\n",
    "        keras.callbacks.ModelCheckpoint(\"my_mnist_multiworker_model.h5\",\n",
    "                                        save_best_only=True),\n",
    "    ]\n",
    "else:\n",
    "    callbacks = []\n",
    "\n",
    "# Load and prepare the MNIST dataset\n",
    "(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.mnist.load_data()\n",
    "X_train_full = X_train_full[..., np.newaxis] / 255.\n",
    "X_valid, X_train = X_train_full[:5000], X_train_full[5000:]\n",
    "y_valid, y_train = y_train_full[:5000], y_train_full[5000:]\n",
    "\n",
    "with distribution.scope():\n",
    "    model = keras.models.Sequential([\n",
    "        keras.layers.Conv2D(filters=64, kernel_size=7, activation=\"relu\",\n",
    "                            padding=\"same\", input_shape=[28, 28, 1]),\n",
    "        keras.layers.MaxPooling2D(pool_size=2),\n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, activation=\"relu\",\n",
    "                            padding=\"same\"),\n",
    "        keras.layers.Conv2D(filters=128, kernel_size=3, activation=\"relu\",\n",
    "                            padding=\"same\"),\n",
    "        keras.layers.MaxPooling2D(pool_size=2),\n",
    "        keras.layers.Flatten(),\n",
    "        keras.layers.Dense(units=64, activation='relu'),\n",
    "        keras.layers.Dropout(0.5),\n",
    "        keras.layers.Dense(units=10, activation='softmax'),\n",
    "    ])\n",
    "    model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "                  optimizer=keras.optimizers.SGD(learning_rate=1e-2),\n",
    "                  metrics=[\"accuracy\"])\n",
    "\n",
    "model.fit(X_train, y_train, validation_data=(X_valid, y_valid),\n",
    "          epochs=10, callbacks=callbacks)"
   ]
  },
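  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The script relies on `TFConfigClusterResolver`, which reads the `TF_CONFIG` environment variable, to find out which task it is running. As a minimal sketch of that idea, the cell below parses a `TF_CONFIG` string by hand and applies the convention described above: the chief is the task of type `\"chief\"` if the cluster defines one, otherwise worker #0. The `is_chief()` helper and the `example_*` names are hypothetical and only meant for illustration; they are not part of TensorFlow's API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def is_chief(tf_config_json):\n",
    "    # Hypothetical helper: decide whether the task described by a TF_CONFIG\n",
    "    # JSON string should act as the chief\n",
    "    tf_config = json.loads(tf_config_json)\n",
    "    task = tf_config[\"task\"]\n",
    "    if \"chief\" in tf_config[\"cluster\"]:\n",
    "        return task[\"type\"] == \"chief\"\n",
    "    # No explicit chief: by convention, worker #0 plays that role\n",
    "    return task[\"type\"] == \"worker\" and task[\"index\"] == 0\n",
    "\n",
    "example_cluster = {\"worker\": [\"127.0.0.1:9901\", \"127.0.0.1:9902\"]}\n",
    "for index in range(2):\n",
    "    example_config = json.dumps({\"cluster\": example_cluster,\n",
    "                                 \"task\": {\"type\": \"worker\", \"index\": index}})\n",
    "    print(\"worker\", index, \"is chief:\", is_chief(example_config))"
   ]
  },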
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In a real-world application, there would typically be a single worker per machine. In this example, however, both workers run on the same machine, so each will try to use all the available GPU RAM (if this machine has a GPU), which will likely lead to an out-of-memory (OOM) error. To avoid this, we could use the `CUDA_VISIBLE_DEVICES` environment variable to assign a different GPU to each worker. Alternatively, we can simply disable GPU support, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
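  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the other option (one GPU per worker) could look roughly like the sketch below. It assumes the machine has at least as many GPUs as workers and that they are numbered from 0. The `make_worker_env()` helper and the `example_cluster` name are hypothetical. The helper only builds a per-worker copy of the environment, with its own `TF_CONFIG` and `CUDA_VISIBLE_DEVICES`, and does not start anything. Such a dict could then be passed to `subprocess.Popen()` through its `env` argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "def make_worker_env(cluster_spec, index, gpu_id=None):\n",
    "    # Hypothetical helper: build the environment for one worker process\n",
    "    env = dict(os.environ)  # copy, so the notebook's own environment is untouched\n",
    "    env[\"TF_CONFIG\"] = json.dumps({\n",
    "        \"cluster\": cluster_spec,\n",
    "        \"task\": {\"type\": \"worker\", \"index\": index}\n",
    "    })\n",
    "    if gpu_id is not None:\n",
    "        # Expose a single GPU to this worker (the GPU numbering is an assumption)\n",
    "        env[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_id)\n",
    "    return env\n",
    "\n",
    "example_cluster = {\"worker\": [\"127.0.0.1:9901\", \"127.0.0.1:9902\"]}\n",
    "for index in range(len(example_cluster[\"worker\"])):\n",
    "    env = make_worker_env(example_cluster, index, gpu_id=index)\n",
    "    print(env[\"TF_CONFIG\"], \"-> GPU\", env[\"CUDA_VISIBLE_DEVICES\"])"
   ]
  },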
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to start both workers, each in its own process, using Python's `subprocess` module. Before we start each process, we need to set the `TF_CONFIG` environment variable appropriately, changing only the task index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import subprocess\n",
    "\n",
    "cluster_spec = {\"worker\": [\"127.0.0.1:9901\", \"127.0.0.1:9902\"]}\n",
    "\n",
    "for index, worker_address in enumerate(cluster_spec[\"worker\"]):\n",
    "    # Each child process inherits os.environ as it is when Popen() is called,\n",
    "    # so we update TF_CONFIG right before starting each worker\n",
    "    os.environ[\"TF_CONFIG\"] = json.dumps({\n",
    "        \"cluster\": cluster_spec,\n",
    "        \"task\": {\"type\": \"worker\", \"index\": index}\n",
    "    })\n",
    "    subprocess.Popen(\"python my_mnist_multiworker_task.py\", shell=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it! Our TensorFlow cluster is now running, but we can't see it in this notebook because it's running in separate processes (but if you are running this notebook in Jupyter, you can see the worker logs in Jupyter's server logs).\n",
    "\n",
    "Since the chief (worker #0) is writing to TensorBoard, we use TensorBoard to view the training progress. Run the following cell, then click on the settings button (i.e., the gear icon) in the TensorBoard interface and check the \"Reload data\" box to make TensorBoard automatically refresh every 30s. Once the first epoch of training is finished (which may take a few minutes), and once TensorBoard refreshes, the SCALARS tab will appear. Click on this tab to view the progress of the model's training and validation accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The tensorboard extension is already loaded. To reload it, use:\n",
      "  %reload_ext tensorboard\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-5a57ac3228038e20\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-5a57ac3228038e20\");\n",
       "          const url = new URL(\"/\", window.location);\n",
       "          const port = 6006;\n",
       "          if (port) {\n",
       "            url.port = port;\n",
       "          }\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir=./my_mnist_multiworker_logs --port=6006"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it! Once training is over, the best checkpoint of the model will be available in the `my_mnist_multiworker_model.h5` file. You can load it using `keras.models.load_model()` and use it for predictions, as usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7, 2, 1])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensorflow import keras\n",
    "\n",
    "model = keras.models.load_model(\"my_mnist_multiworker_model.h5\")\n",
    "Y_pred = model.predict(X_new)\n",
    "np.argmax(Y_pred, axis=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And that's all for today! Hope you found this useful. 😊"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise Solutions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. to 8.\n",
    "\n",
    "See Appendix A."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.\n",
    "_Exercise: Train a model (any model you like) and deploy it to TF Serving or Google Cloud AI Platform. Write the client code to query it using the REST API or the gRPC API. Update the model and deploy the new version. Your client code will now query the new version. Roll back to the first version._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please follow the steps in the <a href=\"#Deploying-TensorFlow-models-to-TensorFlow-Serving-(TFS)\">Deploying TensorFlow models to TensorFlow Serving</a> section above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.\n",
    "_Exercise: Train any model across multiple GPUs on the same machine using the `MirroredStrategy` (if you do not have access to GPUs, you can use Colaboratory with a GPU Runtime and create two virtual GPUs). Train the model again using the `CentralStorageStrategy` and compare the training time._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please follow the steps in the [Distributed Training](#Distributed-Training) section above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 11.\n",
    "_Exercise: Train a small model on Google Cloud AI Platform, using black box hyperparameter tuning._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please follow the instructions on pages 716-717 of the book. You can also read [this documentation page](https://cloud.google.com/ai-platform/training/docs/hyperparameter-tuning-overview) and go through the example in this nice [blog post](https://towardsdatascience.com/how-to-do-bayesian-hyper-parameter-tuning-on-a-blackbox-model-882009552c6d) by Lak Lakshmanan."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
